{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"Tutorial on how to build a convnet w/ modern changes, e.g.\n",
    "Batch Normalization, Leaky rectifiers, and strided convolution.\n",
    "\n",
    "Parag K. Mital, Jan 2016.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%\n",
    "import tensorflow as tf\n",
    "from libs.batch_norm import batch_norm\n",
    "from libs.activations import lrelu\n",
    "from libs.connections import conv2d, linear\n",
    "from libs.datasets import MNIST\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Inputs: load the MNIST dataset wrapper and declare placeholders for\n",
    "# a batch of flattened images (784 values each) and the matching label\n",
    "# vectors (10 values each).  The leading None leaves the batch size open;\n",
    "# real data is fed into these placeholders at run time.\n",
    "mnist = MNIST()\n",
    "n_input, n_output = 784, 10\n",
    "x = tf.placeholder(tf.float32, shape=[None, n_input])\n",
    "y = tf.placeholder(tf.float32, shape=[None, n_output])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% A boolean placeholder that flags whether the graph is being run for\n",
    "# training or for evaluation.  It is passed to the batch normalization\n",
    "# layers and fed as True/False in the training loop.\n",
    "is_training = tf.placeholder(dtype=tf.bool, name='is_training')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Reshape each flat 784-vector into a 28 x 28 single-channel image,\n",
    "# giving a 4-D tensor.  The leading -1 lets TensorFlow infer the batch size.\n",
    "# Layout is N x H x W x C (TensorFlow's default NHWC ordering; assumes the\n",
    "# conv2d helper from libs.connections keeps that default -- TODO confirm).\n",
    "x_tensor = tf.reshape(x, shape=[-1, 28, 28, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Build the network: three conv -> batch norm -> leaky ReLU stages,\n",
    "# then a linear layer with 10 outputs and a softmax.\n",
    "# Batch normalization was proposed to 'reduce internal covariate shift';\n",
    "# following the original paper it is applied before each nonlinearity.\n",
    "# `is_training` is handed to every batch_norm call so the layer can tell\n",
    "# training runs from evaluation runs.\n",
    "conv_1 = conv2d(x_tensor, 32, name='conv1')\n",
    "bn_1 = batch_norm(conv_1, is_training, scope='bn1')\n",
    "h_1 = lrelu(bn_1, name='lrelu1')\n",
    "\n",
    "conv_2 = conv2d(h_1, 64, name='conv2')\n",
    "bn_2 = batch_norm(conv_2, is_training, scope='bn2')\n",
    "h_2 = lrelu(bn_2, name='lrelu2')\n",
    "\n",
    "conv_3 = conv2d(h_2, 64, name='conv3')\n",
    "bn_3 = batch_norm(conv_3, is_training, scope='bn3')\n",
    "h_3 = lrelu(bn_3, name='lrelu3')\n",
    "\n",
    "# NOTE(review): 64 * 4 * 4 presumes conv3 emits 4 x 4 feature maps with 64\n",
    "# channels (e.g. stride-2 convolutions: 28 -> 14 -> 7 -> 4) -- confirm\n",
    "# against the conv2d defaults in libs.connections.\n",
    "n_flat = 64 * 4 * 4\n",
    "h_3_flat = tf.reshape(h_3, [-1, n_flat])\n",
    "h_4 = linear(h_3_flat, 10)\n",
    "y_pred = tf.nn.softmax(h_4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Define loss/eval/training functions.\n",
    "# The loss is the cross entropy between the labels `y` and the softmax of\n",
    "# the logits `h_4`, summed over the batch.  It is computed directly from\n",
    "# the logits rather than as -sum(y * log(y_pred)): once a softmax output\n",
    "# underflows to 0, log(0) is -inf and the loss and its gradients turn\n",
    "# into NaN.\n",
    "cross_entropy = tf.reduce_sum(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(logits=h_4, labels=y))\n",
    "train_step = tf.train.AdamOptimizer().minimize(cross_entropy)\n",
    "\n",
    "# Accuracy: fraction of examples where the index of the largest predicted\n",
    "# probability equals the index of the largest label entry.\n",
    "correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Open a session, then build and run the initializer op so that the\n",
    "# variables in the graph receive their starting values.\n",
    "sess = tf.Session()\n",
    "init_op = tf.initialize_all_variables()\n",
    "sess.run(init_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %% Train in minibatches; after every epoch print the accuracy measured\n",
    "# on the validation set (run with is_training fed as False).\n",
    "n_epochs = 10\n",
    "batch_size = 100\n",
    "for epoch_i in range(n_epochs):\n",
    "    # Any remainder smaller than one batch is dropped by the floor division.\n",
    "    n_batches = mnist.train.num_examples // batch_size\n",
    "    for batch_i in range(n_batches):\n",
    "        batch_xs, batch_ys = mnist.train.next_batch(batch_size)\n",
    "        train_feed = {x: batch_xs, y: batch_ys, is_training: True}\n",
    "        sess.run(train_step, feed_dict=train_feed)\n",
    "    validation_feed = {\n",
    "        x: mnist.validation.images,\n",
    "        y: mnist.validation.labels,\n",
    "        is_training: False,\n",
    "    }\n",
    "    validation_accuracy = sess.run(accuracy, feed_dict=validation_feed)\n",
    "    print(validation_accuracy)"
   ]
  }
 ],
 "metadata": {
  "language": "python"
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
